# -*- coding: utf-8 -*-
# +
# # !rm -rf ./sbm-*/

# +
# # !rm -rf ./results/*

# +
# # !tar cvfz code.tar.gz ./*
# -

import argparse
import logging
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torch_geometric.utils import subgraph
from tqdm import tqdm

from data_loading import get_dataset
from data_utils import set_train_val_test_split, get_feature_mask, get_degree_based_feature_mask, get_group_mask, \
                get_random_walk, get_global_mean, get_normalized_degree
from models import get_model
from seeds import val_seeds
from evaluation import test
from train import train
from reconstruction import spatial_reconstruction
from mean_aggregation import MeanAggregation

# +
def _str_to_bool(value):
    """Parse a command-line token such as ``true`` / ``0`` / ``no`` into a bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


# Command-line interface of the experiment. Every option has a default, so the
# parser can also be driven with an empty argument list.
parser = argparse.ArgumentParser('GNN-Missing-Features')

# EXPERIMENT: dataset, add a new dataset?, do not use 'bail'
parser.add_argument('--dataset_name',
                    type=str,
                    help='Name of dataset',
                    default="sbm",
                    choices=["credit", "german", "bail",
                             "sbm"])

# EXPERIMENT: imputation type
parser.add_argument('--filling_method', type=str, help='Method to solve the missing feature problem', default="neighbor_mean",
                    choices=["global_mean", "neighbor_mean", "feature_propagation", "graph_regularization"])
# `type=list` would split a single token into characters ("0.1" -> ['0', '.', '1']);
# nargs='+' collects one or more properly typed values instead.
parser.add_argument('--epsilons', type=float, nargs='+', help='Epsilons', default=[0.00, 0.025, 0.05])

parser.add_argument('--models', type=str, nargs='+', help='Prediction models', default=["linear", "mlp", "gcn"])
parser.add_argument('--hidden_dim', type=int, help='Hidden dimension of model', default=64)
parser.add_argument('--dropout', type=float, help='Feature dropout', default=0.5)
# Accepts a bare `--reconstruction_only` as well as an explicit true/false value.
# Without a type, any supplied string (including "False") would be truthy.
parser.add_argument('--reconstruction_only', type=_str_to_bool, nargs='?', const=True, default=False)
parser.add_argument('--graph_sampling', help='Set if you want to use graph sampling (always true for large graphs)', action='store_true')

parser.add_argument("--patience", type=int, help="Patience for early stopping", default=200)
parser.add_argument('--lr', type=float, help='Learning Rate', default=0.005)
parser.add_argument('--epochs', type=int, help='Max number of epochs', default=10000)
parser.add_argument('--n_runs', type=int, help='Max number of runs', default=5)

parser.add_argument('--log', type=str, help='Log Level', default="WARNING", choices=["DEBUG", "INFO", "WARNING"])
parser.add_argument('--gpu_idx', type=int, help='Indexes of gpu to run program on', default=4)

# +
# Parse an empty argument list: every option takes the default declared on the parser.
args = parser.parse_args([])

args.plot_over_iterations = True

# (beta, num_iterations) applied for each supported filling method.
_FILLING_SCHEDULES = {
    "global_mean": (0, 1),
    "neighbor_mean": (0, 1),
    "feature_propagation": (0, 40),
    "graph_regularization": (0.25, 40),
}
if args.filling_method not in _FILLING_SCHEDULES:
    raise Exception("Unknown filling method")
args.beta, args.num_iterations = _FILLING_SCHEDULES[args.filling_method]

# Run on the configured GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.gpu_idx}')
else:
    device = torch.device('cpu')
# device = torch.device('cpu')


def _odd_tenths():
    """Return a fresh list [0.1, 0.3, 0.5, 0.7, 0.9]."""
    return [tenth / 10 for tenth in range(1, 10, 2)]


# MCAR = remove some of a node's features completely at random
# MAR = remove some of a node's features randomly conditioned on the group
args.missing_rates_0 = [0.0, *_odd_tenths()]
args.missing_rates_1 = [0.0, *_odd_tenths()]
args.group_sizes_0 = _odd_tenths()
args.group_sizes_1 = _odd_tenths()
args.intra_link_rates = _odd_tenths()
args.inter_link_rates = _odd_tenths()

# EXPERIMENT: importance sampling for missing nodes
# Values used for whichever quantities the active sweep does not iterate over.
missing_rate_deg = False
missing_rate_0 = 0.5
missing_rate_1 = 0.5
group_size_0 = 0.5
group_size_1 = 1.0 - group_size_0
intra_link = 0.5
inter_link = 0.5

# Human-readable method name for plot titles, e.g. "neighbor_mean" -> "Neighbor Mean".
filling_method_name = ' '.join(word.capitalize() for word in args.filling_method.split('_'))

# EXPERIMENT: synthetic dataset parameters
# Active sweep: per-group missing feature rates.
args.experiment_variable = 'Missing Feature Rate'
args.experiment_dim = len(args.missing_rates_0)
args.experiment_list_0 = args.missing_rates_0
args.experiment_list_1 = args.missing_rates_1
args.axis_0 = f'Q {args.experiment_variable}'
args.axis_1 = f'R {args.experiment_variable}'

# Alternative sweep: group sizes.
# args.experiment_variable = 'Group Size'
# args.experiment_dim = len(args.group_sizes_0) 
# args.experiment_list_0 = args.group_sizes_0
# args.experiment_list_1 = args.group_sizes_1
# args.axis_0 = 'Q {}'.format(args.experiment_variable)
# args.axis_1 = 'R {}'.format(args.experiment_variable)

# Alternative sweep: intra-/inter-group link rates.
# args.experiment_variable = 'Link Rate'
# args.experiment_dim = len(args.intra_link_rates) 
# args.experiment_list_0 = args.intra_link_rates
# args.experiment_list_1 = args.inter_link_rates
# args.axis_0 = 'Intra-Link Rate'
# args.axis_1 = 'Inter-Link Rate'

def _run_grid(*trailing_dims):
    """Return a zero tensor indexed by [idx_0, idx_1, run, *trailing_dims]."""
    return torch.zeros((args.experiment_dim, args.experiment_dim, args.n_runs, *trailing_dims))


def _run_grid_per_model():
    """Return one fresh zero grid per entry of ``args.models``."""
    return {model_name: _run_grid() for model_name in args.models}


def _run_grid_per_epsilon_and_model():
    """Return a fresh per-model grid dict for every entry of ``args.epsilons``."""
    return {eps: _run_grid_per_model() for eps in args.epsilons}


# Result buffers, filled in by the experiment loop and later saved and plotted.

# Reconstruction error and test accuracy of the plain imputation ...
plt_reconstruction_errors = _run_grid()
plt_test_accs = _run_grid_per_model()

# ... and of the epsilon-fair imputations.
plt_fair_reconstruction_errors = {eps: _run_grid() for eps in args.epsilons}
plt_fair_test_accs = _run_grid_per_epsilon_and_model()

# Discrimination risk is stored once per imputation iteration, hence the extra axis.
plt_discrimination_risks = _run_grid(args.num_iterations)
plt_alphas = _run_grid()

plt_test_dps = _run_grid_per_model()
plt_test_eos = _run_grid_per_model()

plt_fair_discrimination_risks = {eps: _run_grid(args.num_iterations) for eps in args.epsilons}
plt_fair_test_dps = _run_grid_per_epsilon_and_model()
plt_fair_test_eos = _run_grid_per_epsilon_and_model()

# EXPERIMENT: adversarial recovery of sensitive group information
plt_sens_id_accs = _run_grid_per_model()
plt_fair_sens_id_accs = _run_grid_per_epsilon_and_model()

def _mean_std(values):
    """Return ``(mean, std)`` of a tensor, each taken over all of its elements."""
    return values.mean(), values.std()


def _fit_and_evaluate(model_type, features, data, sens, evaluator, num_classes, use_sens_as_labels=False):
    """Train a fresh model with early stopping and score its best-validation predictions.

    Args:
        model_type: model name forwarded to ``get_model``.
        features: values written into ``data.x[data.train_mask]`` before every epoch.
        data: graph data object already moved to ``device``.
        sens: sensitive-attribute tensor, forwarded to ``train`` and ``test``.
        evaluator: evaluator object forwarded to ``test``.
        num_classes: number of output classes forwarded to ``get_model``.
        use_sens_as_labels: when true, ``use_sens_as_labels=True`` is forwarded to
            ``train`` and ``test``; otherwise the keyword is omitted entirely.

    Returns:
        ``(test_acc, test_dp, test_eo)`` as reported by the final ``test`` call, which
        receives the softmax output of the epoch with the best validation accuracy.
    """
    # Only forward the flag when it is set, so the default code path of train/test is untouched.
    label_kwargs = {'use_sens_as_labels': True} if use_sens_as_labels else {}

    model = get_model(model_name=model_type, num_features=data.num_features, num_classes=num_classes, args=args).to(device)
    optimizer = torch.optim.Adam(list(model.parameters()), lr=args.lr)
    critereon = torch.nn.NLLLoss()

    val_accs = []
    for epoch in range(args.epochs):
        x = data.x.clone()
        x[data.train_mask] = features

        train(model, x, data, sens.to(device), optimizer, critereon, device=device, **label_kwargs)
        (_, val_acc, _), _, _, out = test(model, x=x, data=data, sens=sens.bool().to(device),
                                          evaluator=evaluator, device=device, **label_kwargs)
        # Remember the predictions of the best epoch so far (strict improvement only).
        if epoch == 0 or val_acc > max(val_accs):
            y_soft = out.softmax(dim=-1)

        val_accs.append(val_acc)
        # Stop once the last `patience` epochs did not beat the best earlier validation accuracy.
        if epoch > args.patience and max(val_accs[-args.patience:]) <= max(val_accs[:-args.patience]):
            break

    (_, _, test_acc), (_, _, test_dp), (_, _, test_eo), _ \
        = test(model, x=x, data=data, sens=sens.bool().to(device), logits=y_soft, evaluator=evaluator, **label_kwargs)
    return test_acc, test_dp, test_eo


# Sweep the 2-D experiment grid; for every cell repeat the pipeline once per seed:
# load data -> mask features -> impute (plain and epsilon-fair) -> train/evaluate models.
for idx_0, missing_rate_0 in enumerate(args.missing_rates_0):
    for idx_1, missing_rate_1 in enumerate(args.missing_rates_1):
# for idx_0, group_size_0 in enumerate(args.group_sizes_0):
#     for idx_1, group_size_1 in enumerate(args.group_sizes_1):
# for idx_0, intra_link in enumerate(args.intra_link_rates):
#     for idx_1, inter_link in enumerate(args.inter_link_rates):

        # In the group-size sweep only pairs whose percentages sum to 100 are run.
        if args.experiment_variable == 'Group Size' and int(100 * group_size_0) != 100 - int(100 * group_size_1):
            continue

        train_times = []

        # Wrapping the sequence (not the enumerate object) lets tqdm show the total.
        for seed_idx, seed in enumerate(tqdm(val_seeds[:args.n_runs])):
            # Torch RNG
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            # NumPy RNG
            np.random.seed(seed)

            # Synthetic datasets are addressed as sbm-<group size>-<intra link>-<inter link> (percentages).
            if args.dataset_name.startswith('sbm'):
                args.dataset_name = '-'.join(['sbm', str(int(100 * group_size_0)), str(int(100 * intra_link)), str(int(100 * inter_link))])
            is_sbm = args.dataset_name.startswith('sbm')
            dataset, sens, evaluator = get_dataset(name=args.dataset_name, seed=seed)

            # Restrict features, edges and sensitive attribute to the training nodes.
            train_mask = dataset.data.train_mask
            train_x = dataset.data.x[train_mask]
            train_edge_index, train_edge_attr = subgraph(train_mask, dataset.data.edge_index,
                                                         dataset.data.edge_attr.reshape(-1, 1), relabel_nodes=True,
                                                         num_nodes=dataset.data.x.shape[0])
            train_edge_attr = train_edge_attr.reshape(-1)
            train_sens = sens[train_mask]

            train_n_nodes, n_features = train_x.shape
            num_classes = dataset.num_classes

            if missing_rate_deg:
                feature_mask = get_degree_based_feature_mask(train_edge_index, train_edge_attr,
                                                             n_nodes=train_n_nodes, n_features=n_features).to(device)
            else:
                feature_mask = get_feature_mask(rates=[missing_rate_0, missing_rate_1], group_mask=train_sens.bool(),
                                                n_nodes=train_n_nodes, n_features=n_features).to(device)
            group_mask = train_sens.bool().to(device)

            # Diagnostic: per group value i, the fraction of feature-mask entries equal to j
            # (printed as 'K' for j == 1 and 'U' for j == 0).
            for i in [1, 0]:
                for j in [1, 0]:
                    overlap_mask = feature_mask[group_mask == i] == j
                    j_str = 'K' if j == 1 else 'U'
                    print(i, j_str, overlap_mask.float().mean().item())

            data = dataset.data.to(device)
            train_x = train_x.to(device)
            train_edge_index = train_edge_index.to(device)
            train_edge_attr = train_edge_attr.to(device)
            train_start = time.time()

            x = train_x.clone()
            print("Starting feature filling")
            start = time.time()

            # Pick the propagation graph and per-node transform for the chosen filling method.
            imputation_model = MeanAggregation(num_iterations=args.num_iterations)
            if args.filling_method == "global_mean":
                edge_index_list, edge_weight_list = get_global_mean(feature_mask, train_n_nodes)
                feature_transform = torch.ones(feature_mask.size(0)).to(feature_mask.device)
            elif args.filling_method == "neighbor_mean":
                # EXPERIMENT: possibly run this for more iterations?
                edge_index_list, edge_weight_list = get_random_walk(train_edge_index, train_edge_attr, train_n_nodes, n_features)
                feature_transform = torch.ones(feature_mask.size(0)).to(feature_mask.device)
            elif args.filling_method == "feature_propagation":
                edge_index_list, edge_weight_list = get_random_walk(train_edge_index, train_edge_attr, train_n_nodes, n_features)
                feature_transform = get_normalized_degree(train_edge_index, train_edge_attr, train_n_nodes)
            elif args.filling_method == "graph_regularization":
                edge_index_list, edge_weight_list = get_random_walk(train_edge_index, train_edge_attr, train_n_nodes, n_features)
                feature_transform = get_normalized_degree(train_edge_index, train_edge_attr, train_n_nodes)

            filled_features, fair_filled_features, discrimination_risk, fair_discrimination_risk, risk_bounds, alphas \
                = imputation_model.impute(x=x, edge_index_list=edge_index_list, edge_weight_list=edge_weight_list,
                                          feature_transform=feature_transform, beta=args.beta,
                                          feature_mask=feature_mask, group_mask=group_mask, epsilons=args.epsilons)

            print(f"Feature filling completed. It took: {time.time() - start:.2f}s")

            plt_discrimination_risks[idx_0, idx_1, seed_idx] = torch.tensor(discrimination_risk)
            for epsilon in args.epsilons:
                plt_fair_discrimination_risks[epsilon][idx_0, idx_1, seed_idx] = torch.tensor(fair_discrimination_risk[epsilon])
            plt_alphas[idx_0, idx_1, seed_idx] = alphas.max()

            plt_reconstruction_errors[idx_0, idx_1, seed_idx] = spatial_reconstruction(train_x, filled_features, feature_mask)
            for epsilon in args.epsilons:
                plt_fair_reconstruction_errors[epsilon][idx_0, idx_1, seed_idx] = spatial_reconstruction(train_x, fair_filled_features[epsilon], feature_mask)

            print(f'Reconstruction error: {plt_reconstruction_errors[idx_0, idx_1, seed_idx]}')
            for epsilon in args.epsilons:
                print(f'{epsilon}-fair reconstruction error: {plt_fair_reconstruction_errors[epsilon][idx_0, idx_1, seed_idx]}')

            if args.reconstruction_only:
                continue

            # Models trained on the plain imputation.
            for model_type in args.models:
                # Label prediction is skipped for the synthetic sbm datasets.
                if not is_sbm:
                    test_acc, test_dp, test_eo = _fit_and_evaluate(model_type, filled_features, data, sens, evaluator, num_classes)
                    train_times.append(time.time() - train_start)

                    plt_test_accs[model_type][idx_0, idx_1, seed_idx] = test_acc
                    plt_test_dps[model_type][idx_0, idx_1, seed_idx] = test_dp
                    plt_test_eos[model_type][idx_0, idx_1, seed_idx] = test_eo

                # Adversary: predict the sensitive attribute from the imputed features.
                sens_id_acc, _, _ = _fit_and_evaluate(model_type, filled_features, data, sens, evaluator, num_classes,
                                                      use_sens_as_labels=True)
                plt_sens_id_accs[model_type][idx_0, idx_1, seed_idx] = sens_id_acc

            # Same two evaluations for every epsilon-fair imputation.
            for epsilon in args.epsilons:
                for model_type in args.models:
                    if not is_sbm:
                        fair_test_acc, fair_test_dp, fair_test_eo = _fit_and_evaluate(
                            model_type, fair_filled_features[epsilon], data, sens, evaluator, num_classes)

                        plt_fair_test_accs[epsilon][model_type][idx_0, idx_1, seed_idx] = fair_test_acc
                        plt_fair_test_dps[epsilon][model_type][idx_0, idx_1, seed_idx] = fair_test_dp
                        plt_fair_test_eos[epsilon][model_type][idx_0, idx_1, seed_idx] = fair_test_eo

                    fair_sens_id_acc, _, _ = _fit_and_evaluate(
                        model_type, fair_filled_features[epsilon], data, sens, evaluator, num_classes,
                        use_sens_as_labels=True)
                    plt_fair_sens_id_accs[epsilon][model_type][idx_0, idx_1, seed_idx] = fair_sens_id_acc

        # Per-cell summary over runs. `results` accumulates every statistic of the cell;
        # entries are added in place so that no earlier entry is discarded.
        relative_reconstruction_error_mean, relative_reconstruction_error_std = _mean_std(plt_reconstruction_errors[idx_0, idx_1])
        results = {"relative_reconstruction_error_mean": relative_reconstruction_error_mean,
                   "relative_reconstruction_error_std": relative_reconstruction_error_std}
        print(f'Reconstruction error: {relative_reconstruction_error_mean} +- {relative_reconstruction_error_std}')

        for epsilon in args.epsilons:
            fair_relative_reconstruction_error_mean, fair_relative_reconstruction_error_std \
                = _mean_std(plt_fair_reconstruction_errors[epsilon][idx_0, idx_1])
            results[str(epsilon) + "_fair_relative_reconstruction_error_mean"] = fair_relative_reconstruction_error_mean
            results[str(epsilon) + "_fair_relative_reconstruction_error_std"] = fair_relative_reconstruction_error_std
            print(f'{epsilon}-fair reconstruction error: {fair_relative_reconstruction_error_mean} +- {fair_relative_reconstruction_error_std}')

        if not args.reconstruction_only:
            for model_type in args.models:
                test_acc_mean, test_acc_std = _mean_std(plt_test_accs[model_type][idx_0, idx_1])
                test_dp_mean, test_dp_std = _mean_std(plt_test_dps[model_type][idx_0, idx_1])
                test_eo_mean, test_eo_std = _mean_std(plt_test_eos[model_type][idx_0, idx_1])
                sens_id_acc_mean, sens_id_acc_std = _mean_std(plt_sens_id_accs[model_type][idx_0, idx_1])

                print(f'{model_type} Test Accuracy: {test_acc_mean} +- {test_acc_std}')
                print(f'{model_type} Test Demographic Parity: {test_dp_mean} +- {test_dp_std}')
                print(f'{model_type} Test Equal Opportunity Parity: {test_eo_mean} +- {test_eo_std}')
                print(f'{model_type} Sensitive ID Accuracy: {sens_id_acc_mean} +- {sens_id_acc_std}')
                results.update({
                    model_type + "_test_acc_mean": test_acc_mean,
                    model_type + "_test_acc_std": test_acc_std,
                    model_type + "_test_dp_mean": test_dp_mean,
                    model_type + "_test_dp_std": test_dp_std,
                    model_type + "_test_eo_mean": test_eo_mean,
                    model_type + "_test_eo_std": test_eo_std,
                    model_type + "_sens_id_acc_mean": sens_id_acc_mean,
                    model_type + "_sens_id_acc_std": sens_id_acc_std,
                })

            for epsilon in args.epsilons:
                for model_type in args.models:
                    fair_test_acc_mean, fair_test_acc_std = _mean_std(plt_fair_test_accs[epsilon][model_type][idx_0, idx_1])
                    fair_test_dp_mean, fair_test_dp_std = _mean_std(plt_fair_test_dps[epsilon][model_type][idx_0, idx_1])
                    fair_test_eo_mean, fair_test_eo_std = _mean_std(plt_fair_test_eos[epsilon][model_type][idx_0, idx_1])
                    fair_sens_id_acc_mean, fair_sens_id_acc_std = _mean_std(plt_fair_sens_id_accs[epsilon][model_type][idx_0, idx_1])

                    print(f'{model_type} {epsilon}-fair Test Accuracy: {fair_test_acc_mean} +- {fair_test_acc_std}')
                    print(f'{model_type} {epsilon}-fair Test Demographic Parity: {fair_test_dp_mean} +- {fair_test_dp_std}')
                    print(f'{model_type} {epsilon}-fair Test Equal Opportunity Parity: {fair_test_eo_mean} +- {fair_test_eo_std}')
                    print(f'{model_type} {epsilon}-fair Sensitive ID Accuracy: {fair_sens_id_acc_mean} +- {fair_sens_id_acc_std}')
                    # Every key carries epsilon, so different epsilons do not overwrite each other.
                    fair_key = model_type + "_" + str(epsilon)
                    results.update({
                        fair_key + "_fair_test_acc_mean": fair_test_acc_mean,
                        fair_key + "_fair_test_acc_std": fair_test_acc_std,
                        fair_key + "_fair_test_dp_mean": fair_test_dp_mean,
                        fair_key + "_fair_test_dp_std": fair_test_dp_std,
                        fair_key + "_fair_test_eo_mean": fair_test_eo_mean,
                        fair_key + "_fair_test_eo_std": fair_test_eo_std,
                        fair_key + "_fair_sens_id_acc_mean": fair_sens_id_acc_mean,
                        fair_key + "_fair_sens_id_acc_std": fair_sens_id_acc_std,
                    })
        print()
        print()
            
# -

# ## Save outputs

# +
# Persist every result buffer under cache_out/<dataset>-<method>-<experiment>-<name>.pt
filename_prefix = 'cache_out/{}-{}-{}-'.format(args.dataset_name,
                                               args.filling_method,
                                               '_'.join(args.experiment_variable.lower().split()))

# torch.save does not create missing parent directories; make sure the target
# directory exists so a finished sweep is not lost at the very last step.
os.makedirs(os.path.dirname(filename_prefix), exist_ok=True)

# File-name suffix -> object to serialize.
_cached_outputs = {
    'reconstruction_errors.pt': plt_reconstruction_errors,
    'test_accs.pt': plt_test_accs,
    'fair_reconstruction_errors.pt': plt_fair_reconstruction_errors,
    'fair_test_accs.pt': plt_fair_test_accs,
    'discrimination_risks.pt': plt_discrimination_risks,
    'alphas.pt': plt_alphas,
    'test_dps.pt': plt_test_dps,
    'test_eos.pt': plt_test_eos,
    'fair_discrimination_risks.pt': plt_fair_discrimination_risks,
    'fair_test_dps.pt': plt_fair_test_dps,
    'fair_test_eos.pt': plt_fair_test_eos,
    'sens_id_accs.pt': plt_sens_id_accs,
    'fair_sens_id_accs.pt': plt_fair_sens_id_accs,
}
for _suffix, _payload in _cached_outputs.items():
    torch.save(_payload, filename_prefix + _suffix)
# -

# ## Plot results

# +
# Reload the cached result buffers so the plotting cells can run without re-running the sweep.
_experiment_slug = '_'.join(args.experiment_variable.lower().split())
filename_prefix = f'cache_out/{args.dataset_name}-{args.filling_method}-{_experiment_slug}-'


def _load_cached(suffix):
    """Load the cached object stored at ``filename_prefix + suffix``."""
    return torch.load(filename_prefix + suffix)


plt_reconstruction_errors = _load_cached('reconstruction_errors.pt')
plt_test_accs = _load_cached('test_accs.pt')

plt_fair_reconstruction_errors = _load_cached('fair_reconstruction_errors.pt')
plt_fair_test_accs = _load_cached('fair_test_accs.pt')

plt_discrimination_risks = _load_cached('discrimination_risks.pt')
plt_alphas = _load_cached('alphas.pt')
plt_test_dps = _load_cached('test_dps.pt')
plt_test_eos = _load_cached('test_eos.pt')

plt_fair_discrimination_risks = _load_cached('fair_discrimination_risks.pt')
plt_fair_test_dps = _load_cached('fair_test_dps.pt')
plt_fair_test_eos = _load_cached('fair_test_eos.pt')

plt_sens_id_accs = _load_cached('sens_id_accs.pt')
plt_fair_sens_id_accs = _load_cached('fair_sens_id_accs.pt')

# +
# Global figure defaults for every plot produced below.
plt.rcParams["figure.figsize"] = (20, 10)
plt.rcParams.update({'font.size': 16})

# Risk-vs-iteration curves are only drawn for the two 40-iteration filling methods.
_iterative_methods = ("feature_propagation", "graph_regularization")
if args.plot_over_iterations and args.filling_method in _iterative_methods:
    # Fixed grid cell whose curves are shown.
    idx_0, idx_1 = 2, 2

    fig, ax1 = plt.subplots()

    def _plot_risk(risk_per_run, series_label):
        """Draw mean +- std over runs (second-to-last axis) against the iteration number."""
        ax1.errorbar(range(1, args.num_iterations + 1), risk_per_run.mean(dim=-2),
                     yerr=risk_per_run.std(dim=-2), label=series_label, fmt='o')

    _plot_risk(plt_discrimination_risks[idx_0, idx_1], 'True Risk')
    _plot_risk(plt_fair_discrimination_risks[0.025][idx_0, idx_1], '0.025-Fair Risk')

    ax1.set_xlabel("Iterations")
    ax1.set_ylabel("Discrimination Risk")

    plt.title(f"Discrimination Risk of {filling_method_name}")
    plt.legend()

    _experiment_slug = '_'.join(args.experiment_variable.lower().split())
    plt.savefig(f'results/{args.dataset_name}-{args.filling_method}-{_experiment_slug}-risk_over_iterations.pdf')
# -

# Overwrite grid cell (0, 0) with 0 for all runs, in the plain and in every epsilon-fair error grid.
# NOTE(review): presumably because in the missing-feature-rate sweep index 0 is rate 0.0 for
# both groups, so nothing is imputed there — confirm; other sweeps would lose a real cell.
for _error_grid in (plt_reconstruction_errors,
                    *(plt_fair_reconstruction_errors[eps] for eps in args.epsilons)):
    _error_grid[0, 0] = 0

# +
# Collapse the run axis (last axis) of every result buffer into a mean and a standard deviation.

# Reconstruction error.
plt_reconstruction_error_means = plt_reconstruction_errors.mean(dim=-1)
plt_reconstruction_error_stds = plt_reconstruction_errors.std(dim=-1)

plt_fair_reconstruction_error_means = {eps: plt_fair_reconstruction_errors[eps].mean(dim=-1) for eps in args.epsilons}
plt_fair_reconstruction_error_stds = {eps: plt_fair_reconstruction_errors[eps].std(dim=-1) for eps in args.epsilons}

# Test accuracy.
plt_test_acc_means = {name: plt_test_accs[name].mean(dim=-1) for name in args.models}
plt_test_acc_stds = {name: plt_test_accs[name].std(dim=-1) for name in args.models}

plt_fair_test_acc_means = {eps: {name: plt_fair_test_accs[eps][name].mean(dim=-1) for name in args.models}
                           for eps in args.epsilons}
plt_fair_test_acc_stds = {eps: {name: plt_fair_test_accs[eps][name].std(dim=-1) for name in args.models}
                          for eps in args.epsilons}

# Discrimination risk: only the value after the final imputation iteration is summarised.
_final_risks = plt_discrimination_risks[..., -1]
plt_discrimination_risk_means = _final_risks.mean(dim=-1)
plt_discrimination_risk_stds = _final_risks.std(dim=-1)
plt_alpha_means = plt_alphas.mean(dim=-1)
plt_alpha_stds = plt_alphas.std(dim=-1)

plt_fair_discrimination_risk_means = {eps: plt_fair_discrimination_risks[eps][..., -1].mean(dim=-1)
                                      for eps in args.epsilons}
plt_fair_discrimination_risk_stds = {eps: plt_fair_discrimination_risks[eps][..., -1].std(dim=-1)
                                     for eps in args.epsilons}

# Demographic parity and equal opportunity.
plt_test_dp_means = {name: plt_test_dps[name].mean(dim=-1) for name in args.models}
plt_test_dp_stds = {name: plt_test_dps[name].std(dim=-1) for name in args.models}

plt_test_eo_means = {name: plt_test_eos[name].mean(dim=-1) for name in args.models}
plt_test_eo_stds = {name: plt_test_eos[name].std(dim=-1) for name in args.models}

plt_fair_test_dp_means = {eps: {name: plt_fair_test_dps[eps][name].mean(dim=-1) for name in args.models}
                          for eps in args.epsilons}
plt_fair_test_dp_stds = {eps: {name: plt_fair_test_dps[eps][name].std(dim=-1) for name in args.models}
                         for eps in args.epsilons}
plt_fair_test_eo_means = {eps: {name: plt_fair_test_eos[eps][name].mean(dim=-1) for name in args.models}
                          for eps in args.epsilons}
plt_fair_test_eo_stds = {eps: {name: plt_fair_test_eos[eps][name].std(dim=-1) for name in args.models}
                         for eps in args.epsilons}

# Sensitive-attribute identification accuracy.
plt_sens_id_acc_means = {name: plt_sens_id_accs[name].mean(dim=-1) for name in args.models}
plt_sens_id_acc_stds = {name: plt_sens_id_accs[name].std(dim=-1) for name in args.models}

plt_fair_sens_id_acc_means = {eps: {name: plt_fair_sens_id_accs[eps][name].mean(dim=-1) for name in args.models}
                              for eps in args.epsilons}
plt_fair_sens_id_acc_stds = {eps: {name: plt_fair_sens_id_accs[eps][name].std(dim=-1) for name in args.models}
                             for eps in args.epsilons}
# -

# ## Missing Feature Rate

if args.experiment_variable == 'Missing Feature Rate':
    # Average over all grid cells except the first row and column, which
    # leaves one value per entry of the last axis; report the mean and std
    # of those values.
    per_run = plt_reconstruction_errors[1:, 1:].mean(dim=(0, 1))
    print(f'Aggregate reconstruction error: {per_run.mean()} ± {per_run.std()}')
    for eps in args.epsilons:
        per_run = plt_fair_reconstruction_errors[eps][1:, 1:].mean(dim=(0, 1))
        print(f'Aggregate {eps}-fair reconstruction error: {per_run.mean()} ± {per_run.std()}')

# +
plt.rcParams["figure.figsize"] = (20,20)
plt.rcParams.update({'font.size': 16})

if args.experiment_variable == 'Missing Feature Rate':
    # One heatmap panel for the regular filling method (tagged with epsilon
    # == -1) followed by one panel per epsilon-fair variant.
    panels = [(-1, plt_reconstruction_error_means, plt_reconstruction_error_stds)] \
        + [(eps, plt_fair_reconstruction_error_means[eps], plt_fair_reconstruction_error_stds[eps])
           for eps in args.epsilons]

    # Two panels per row. Keep at least two rows so the default (up to four
    # panels) still yields the 2x2 grid, and add rows when more epsilons are
    # configured instead of running off the end of a fixed 2x2 grid.
    n_rows = max(2, (len(panels) + 1) // 2)
    fig, axes = plt.subplots(n_rows, 2)

    # Shared colour scale across all panels so they are directly comparable.
    vmin = min(means.min().item() for _, means, _ in panels)
    vmax = max(means.max().item() for _, means, _ in panels)

    for idx, (epsilon, rec_error_means, rec_error_stds) in enumerate(panels):
        # Annotation text "mean\n± std" per cell. The heatmap shows the
        # transposed means, so the outer index runs over axis 1.
        labels = np.array([
            ["{:.3f}\n± {:.3f}".format(rec_error_means[idx_0, idx_1].item(),
                                       rec_error_stds[idx_0, idx_1].item())
             for idx_0 in range(rec_error_means.size(0))]
            for idx_1 in range(rec_error_means.size(1))
        ])

        panel_ax = axes[idx // 2, idx % 2]
        sns.heatmap(rec_error_means.t(), annot=labels, fmt='', vmin=vmin, vmax=vmax, ax=panel_ax)
        # Put the first entry of experiment_list_1 at the bottom of the plot.
        panel_ax.invert_yaxis()
        if epsilon == -1:
            panel_ax.set_title("Regular " + filling_method_name)
        else:
            panel_ax.set_title(str(epsilon) + "-Fair " + filling_method_name)
        panel_ax.set_xticklabels(args.experiment_list_0)
        panel_ax.set_yticklabels(args.experiment_list_1)
        panel_ax.set_xlabel(args.axis_0)
        panel_ax.set_ylabel(args.axis_1)

    plt.suptitle('Reconstruction Error')

    plt.tight_layout()
    plt.savefig('results/{}-{}-{}-{}.pdf'.format(args.dataset_name, \
                                                args.filling_method, \
                                                 '_'.join(args.experiment_variable.lower().split()), \
                                                'reconstruction_error'))
# -

if args.experiment_variable == 'Missing Feature Rate':
    # For each model: average over all grid cells except the first row and
    # column, then report mean and std over the remaining (last) axis.
    for model_name in args.models:
        per_run = plt_test_accs[model_name][1:, 1:].mean(dim=(0, 1))
        print(f'Aggregate test accuracy of {model_name}: {per_run.mean()} ± {per_run.std()}')
    # Same summary for every epsilon-fair variant.
    for eps in args.epsilons:
        for model_name in args.models:
            per_run = plt_fair_test_accs[eps][model_name][1:, 1:].mean(dim=(0, 1))
            print(f'Aggregate {eps}-fair test accuracy of {model_name}: {per_run.mean()} ± {per_run.std()}')

# +
plt.rcParams["figure.figsize"] = (20,20)
plt.rcParams.update({'font.size': 16})

if args.experiment_variable == 'Missing Feature Rate':
    # One figure per prediction model.
    for model_type in args.models:
        # One heatmap panel for the regular filling method (tagged with
        # epsilon == -1) followed by one panel per epsilon-fair variant.
        panels = [(-1, plt_test_acc_means[model_type], plt_test_acc_stds[model_type])] \
            + [(eps, plt_fair_test_acc_means[eps][model_type], plt_fair_test_acc_stds[eps][model_type])
               for eps in args.epsilons]

        # Two panels per row. Keep at least two rows so the default (up to
        # four panels) still yields the 2x2 grid, and add rows when more
        # epsilons are configured instead of overrunning a fixed 2x2 grid.
        n_rows = max(2, (len(panels) + 1) // 2)
        fig, axes = plt.subplots(n_rows, 2)

        # Shared colour scale across all panels of this figure.
        vmin = min(means.min().item() for _, means, _ in panels)
        vmax = max(means.max().item() for _, means, _ in panels)

        for idx, (epsilon, acc_means, acc_stds) in enumerate(panels):
            # Annotation text "mean\n± std" per cell. The heatmap shows the
            # transposed means, so the outer index runs over axis 1.
            labels = np.array([
                ["{:.3f}\n± {:.3f}".format(acc_means[idx_0, idx_1].item(),
                                           acc_stds[idx_0, idx_1].item())
                 for idx_0 in range(acc_means.size(0))]
                for idx_1 in range(acc_means.size(1))
            ])

            panel_ax = axes[idx // 2, idx % 2]
            sns.heatmap(acc_means.t(), annot=labels, fmt='', vmin=vmin, vmax=vmax, ax=panel_ax)
            # Put the first entry of experiment_list_1 at the bottom.
            panel_ax.invert_yaxis()
            if epsilon == -1:
                panel_ax.set_title("Regular " + filling_method_name)
            else:
                panel_ax.set_title(str(epsilon) + "-Fair " + filling_method_name)
            panel_ax.set_xticklabels(args.experiment_list_0)
            panel_ax.set_yticklabels(args.experiment_list_1)
            panel_ax.set_xlabel(args.axis_0)
            panel_ax.set_ylabel(args.axis_1)

        plt.suptitle('Test Accuracy of {} model'.format(model_type.upper()))
        plt.tight_layout()
        plt.savefig('results/{}-{}-{}-{}-{}.pdf'.format(args.dataset_name, \
                                                    args.filling_method, \
                                                     '_'.join(args.experiment_variable.lower().split()), \
                                                     model_type, \
                                                    'test_accuracy'))
# -

if args.experiment_variable == 'Missing Feature Rate':
    # Take the last entry along the fourth axis, average over all grid cells
    # except the first row and column, and report mean and std of the result.
    per_run = plt_discrimination_risks[1:, 1:, :, -1].mean(dim=(0, 1))
    print(f'Aggregate discrimination risk: {per_run.mean()} ± {per_run.std()}')
    for eps in args.epsilons:
        per_run = plt_fair_discrimination_risks[eps][1:, 1:, :, -1].mean(dim=(0, 1))
        print(f'Aggregate {eps}-fair discrimination risk: {per_run.mean()} ± {per_run.std()}')

# +
# Tall figure: two heatmaps stacked vertically.
plt.rcParams["figure.figsize"] = (10,20)
plt.rcParams.update({'font.size': 16})

if args.experiment_variable == 'Missing Feature Rate':
    # Draws two panels: the discrimination risk of the regular filling method
    # (tagged -1) and the alpha statistics (tagged -2). The epsilon-fair
    # panels are currently disabled; their code is kept commented out below.
    # fig, axes = plt.subplots(2, 2)
    fig, axes = plt.subplots(2)

    # Colour range of the risk panel, taken from the regular method only.
    vmin = plt_discrimination_risk_means.min().item()
    vmax = plt_discrimination_risk_means.max().item()
#     for epsilon in args.epsilons:
#         if epsilon == 0:
#             continue
#         vmin = min(vmin, (plt_fair_discrimination_risk_means[epsilon]).min().item())
#         vmax = max(vmax, (plt_fair_discrimination_risk_means[epsilon]).max().item())
    
    for idx, (epsilon, risk_means, risk_stds) in enumerate([(-1, plt_discrimination_risk_means, plt_discrimination_risk_stds)] \
                                                           + [(-2, plt_alpha_means, plt_alpha_stds)]): # \
                                                 # + [(epsilon, plt_fair_discrimination_risk_means[epsilon], plt_fair_discrimination_risk_stds[epsilon]) for epsilon in args.epsilons if epsilon != 0]):
        # Build the per-cell annotation "mean\n± std". The outer index runs
        # over axis 1 because the heatmap below shows the transposed tensor.
        labels = []
        for idx_1 in range(risk_means.size(1)):
            labels.append([])
            for idx_0 in range(risk_means.size(0)):
                labels[-1].append("{:.3f}".format(risk_means[idx_0, idx_1].item()) \
                                  + '\n± ' \
                                  + "{:.3f}".format(risk_stds[idx_0, idx_1].item()))
        labels = np.array(labels)

        # draw_ax = axes[idx // 2, idx % 2]
        draw_ax = axes[idx]
        if epsilon != -2:
            # Risk panel: use the colour range computed above.
            ax = sns.heatmap(risk_means.t(), annot=labels, fmt='', vmin=vmin, vmax=vmax, ax=draw_ax)
        else:
            # Alpha panel: let seaborn choose its own colour range.
            ax = sns.heatmap(risk_means.t(), annot=labels, fmt='', ax=draw_ax)
        # Put the first entry of experiment_list_1 at the bottom.
        ax.invert_yaxis()
        if epsilon == -1:
            ax.set_title("Regular " + filling_method_name)
        elif epsilon == -2:
            ax.set_title('Max Alpha of {}'.format(filling_method_name))
        else:
            # Not reached with the current two-panel list; only relevant if
            # the epsilon-fair panels above are re-enabled.
            ax.set_title(str(epsilon) + "-Fair " + filling_method_name)
        ax.set_xticklabels(args.experiment_list_0)
        ax.set_yticklabels(args.experiment_list_1)
        ax.set_xlabel(args.axis_0)
        ax.set_ylabel(args.axis_1)
    
    # plt.suptitle('Discrimination Risk') 
    plt.tight_layout()
    plt.savefig('results/{}-{}-{}-{}.pdf'.format(args.dataset_name, \
                                         args.filling_method, \
                                         '_'.join(args.experiment_variable.lower().split()), \
                                        'discrimination_risk'))  
# -

if args.experiment_variable == 'Missing Feature Rate':
    # For each model: average the dp metric over all grid cells except the
    # first row and column, then report mean and std over the last axis.
    for model_name in args.models:
        per_run = plt_test_dps[model_name][1:, 1:].mean(dim=(0, 1))
        print(f'Aggregate statistical parity of {model_name}: {per_run.mean()} ± {per_run.std()}')
    # Same summary for every epsilon-fair variant.
    for eps in args.epsilons:
        for model_name in args.models:
            per_run = plt_fair_test_dps[eps][model_name][1:, 1:].mean(dim=(0, 1))
            print(f'Aggregate {eps}-fair statistical parity of {model_name}: {per_run.mean()} ± {per_run.std()}')

if args.experiment_variable == 'Missing Feature Rate':
    # For each model: average the eo metric over all grid cells except the
    # first row and column, then report mean and std over the last axis.
    # NOTE(review): the tensors are named "*_eos" while the message says
    # "predictive parity" -- confirm which fairness metric is stored here.
    for model_name in args.models:
        per_run = plt_test_eos[model_name][1:, 1:].mean(dim=(0, 1))
        print(f'Aggregate predictive parity of {model_name}: {per_run.mean()} ± {per_run.std()}')
    # Same summary for every epsilon-fair variant.
    for eps in args.epsilons:
        for model_name in args.models:
            per_run = plt_fair_test_eos[eps][model_name][1:, 1:].mean(dim=(0, 1))
            print(f'Aggregate {eps}-fair predictive parity of {model_name}: {per_run.mean()} ± {per_run.std()}')

# +
plt.rcParams["figure.figsize"] = (20,20)
plt.rcParams.update({'font.size': 16})

if args.experiment_variable == 'Missing Feature Rate':
    # One figure per prediction model.
    for model_type in args.models:
        # One heatmap panel for the regular filling method (tagged with
        # epsilon == -1) followed by one panel per epsilon-fair variant.
        panels = [(-1, plt_test_dp_means[model_type], plt_test_dp_stds[model_type])] \
            + [(eps, plt_fair_test_dp_means[eps][model_type], plt_fair_test_dp_stds[eps][model_type])
               for eps in args.epsilons]

        # Two panels per row. Keep at least two rows so the default (up to
        # four panels) still yields the 2x2 grid, and add rows when more
        # epsilons are configured instead of overrunning a fixed 2x2 grid.
        n_rows = max(2, (len(panels) + 1) // 2)
        fig, axes = plt.subplots(n_rows, 2)

        # Shared colour scale across all panels of this figure.
        vmin = min(means.min().item() for _, means, _ in panels)
        vmax = max(means.max().item() for _, means, _ in panels)

        for idx, (epsilon, dp_means, dp_stds) in enumerate(panels):
            # Annotation text "mean\n± std" per cell. The heatmap shows the
            # transposed means, so the outer index runs over axis 1.
            labels = np.array([
                ["{:.3f}\n± {:.3f}".format(dp_means[idx_0, idx_1].item(),
                                           dp_stds[idx_0, idx_1].item())
                 for idx_0 in range(dp_means.size(0))]
                for idx_1 in range(dp_means.size(1))
            ])

            panel_ax = axes[idx // 2, idx % 2]
            sns.heatmap(dp_means.t(), annot=labels, fmt='', vmin=vmin, vmax=vmax, ax=panel_ax)
            # Put the first entry of experiment_list_1 at the bottom.
            panel_ax.invert_yaxis()
            if epsilon == -1:
                panel_ax.set_title("Regular " + filling_method_name)
            else:
                panel_ax.set_title(str(epsilon) + "-Fair " + filling_method_name)
            panel_ax.set_xticklabels(args.experiment_list_0)
            panel_ax.set_yticklabels(args.experiment_list_1)
            panel_ax.set_xlabel(args.axis_0)
            panel_ax.set_ylabel(args.axis_1)

        plt.suptitle('Test Demographic Parity of {} model'.format(model_type.upper()))
        plt.tight_layout()
        plt.savefig('results/{}-{}-{}-{}-{}.pdf'.format(args.dataset_name, \
                                                    args.filling_method, \
                                                     '_'.join(args.experiment_variable.lower().split()), \
                                                     model_type, \
                                                    'test_dp'))
# -

if args.experiment_variable == 'Missing Feature Rate':
    # For each model: average over all grid cells except the first row and
    # column, then report mean and std over the remaining (last) axis.
    for model_name in args.models:
        per_run = plt_sens_id_accs[model_name][1:, 1:].mean(dim=(0, 1))
        print(f'Aggregate sensitive group identification accuracy of {model_name}: {per_run.mean()} ± {per_run.std()}')
    # Same summary for every epsilon-fair variant.
    for eps in args.epsilons:
        for model_name in args.models:
            per_run = plt_fair_sens_id_accs[eps][model_name][1:, 1:].mean(dim=(0, 1))
            print(f'Aggregate {eps}-fair sensitive group identification accuracy of {model_name}: {per_run.mean()} ± {per_run.std()}')

# +
plt.rcParams["figure.figsize"] = (20,20)
plt.rcParams.update({'font.size': 16})

if args.experiment_variable == 'Missing Feature Rate':
    # One figure per prediction model.
    for model_type in args.models:
        # One heatmap panel for the regular filling method (tagged with
        # epsilon == -1) followed by one panel per epsilon-fair variant.
        panels = [(-1, plt_sens_id_acc_means[model_type], plt_sens_id_acc_stds[model_type])] \
            + [(eps, plt_fair_sens_id_acc_means[eps][model_type], plt_fair_sens_id_acc_stds[eps][model_type])
               for eps in args.epsilons]

        # Two panels per row. Keep at least two rows so the default (up to
        # four panels) still yields the 2x2 grid, and add rows when more
        # epsilons are configured instead of overrunning a fixed 2x2 grid.
        n_rows = max(2, (len(panels) + 1) // 2)
        fig, axes = plt.subplots(n_rows, 2)

        # Shared colour scale across all panels of this figure.
        vmin = min(means.min().item() for _, means, _ in panels)
        vmax = max(means.max().item() for _, means, _ in panels)

        for idx, (epsilon, acc_means, acc_stds) in enumerate(panels):
            # Annotation text "mean\n± std" per cell. The heatmap shows the
            # transposed means, so the outer index runs over axis 1.
            labels = np.array([
                ["{:.3f}\n± {:.3f}".format(acc_means[idx_0, idx_1].item(),
                                           acc_stds[idx_0, idx_1].item())
                 for idx_0 in range(acc_means.size(0))]
                for idx_1 in range(acc_means.size(1))
            ])

            panel_ax = axes[idx // 2, idx % 2]
            sns.heatmap(acc_means.t(), annot=labels, fmt='', vmin=vmin, vmax=vmax, ax=panel_ax)
            # Put the first entry of experiment_list_1 at the bottom.
            panel_ax.invert_yaxis()
            if epsilon == -1:
                panel_ax.set_title("Regular " + filling_method_name)
            else:
                panel_ax.set_title(str(epsilon) + "-Fair " + filling_method_name)
            panel_ax.set_xticklabels(args.experiment_list_0)
            panel_ax.set_yticklabels(args.experiment_list_1)
            panel_ax.set_xlabel(args.axis_0)
            panel_ax.set_ylabel(args.axis_1)

        plt.suptitle('Sensitive Group Identification Accuracy of {} model'.format(model_type.upper()))
        plt.tight_layout()
        plt.savefig('results/{}-{}-{}-{}-{}.pdf'.format(args.dataset_name, \
                                                    args.filling_method, \
                                                     '_'.join(args.experiment_variable.lower().split()), \
                                                     model_type, \
                                                    'test_accuracy'.replace('test_accuracy', 'sens_id_accuracy')))
# -

# ## Group Size

if args.experiment_variable == 'Group Size':
    plt.figure()
    ax = plt.axes()

    # One error-bar curve for the regular method (tagged -1) and one per
    # epsilon-fair variant.
    curves = [(-1, plt_reconstruction_error_means, plt_reconstruction_error_stds)]
    curves += [(eps, plt_fair_reconstruction_error_means[eps], plt_fair_reconstruction_error_stds[eps])
               for eps in args.epsilons]

    for tag, curve_means, curve_stds in curves:
        label = ("Regular " if tag == -1 else str(tag) + "-Fair ") + filling_method_name
        # Means and stds are summed over axis 1 to get one value per entry
        # of experiment_list_0.
        ax.errorbar(args.experiment_list_0,
                    curve_means.sum(dim=1),
                    yerr=curve_stds.sum(dim=1),
                    label=label)

    ax.set_xlabel(args.axis_0)
    ax.set_ylabel("Reconstruction Error")
    plt.legend()
    plt.tight_layout()

    variable_slug = '_'.join(args.experiment_variable.lower().split())
    plt.savefig('results/{}-{}-{}-{}.pdf'.format(args.dataset_name,
                                                 args.filling_method,
                                                 variable_slug,
                                                 'reconstruction_error'))

if args.experiment_variable == 'Group Size':

    # One figure per prediction model.
    for model_type in args.models:
        plt.figure()
        ax = plt.axes()

        # Regular method (tagged -1) first, then each epsilon-fair variant.
        curves = [(-1, plt_test_acc_means[model_type], plt_test_acc_stds[model_type])]
        curves += [(eps, plt_fair_test_acc_means[eps][model_type], plt_fair_test_acc_stds[eps][model_type])
                   for eps in args.epsilons]

        for tag, curve_means, curve_stds in curves:
            label = ("Regular " if tag == -1 else str(tag) + "-Fair ") + filling_method_name
            # Summed over axis 1: one value per entry of experiment_list_0.
            ax.errorbar(args.experiment_list_0,
                        curve_means.sum(dim=1),
                        yerr=curve_stds.sum(dim=1),
                        label=label)

        ax.set_xlabel(args.axis_0)
        ax.set_ylabel('Test Accuracy of {} model'.format(model_type.upper()))
        plt.legend()
        plt.tight_layout()

        variable_slug = '_'.join(args.experiment_variable.lower().split())
        plt.savefig('results/{}-{}-{}-{}-{}.pdf'.format(args.dataset_name,
                                                        args.filling_method,
                                                        variable_slug,
                                                        model_type,
                                                        'test_accuracy'))

if args.experiment_variable == 'Group Size':
    # Top panel: discrimination risk curves; bottom panel: max alpha.
    fig, axs = plt.subplots(2)
    risk_ax, alpha_ax = axs[0], axs[1]

    # Regular method (tagged -1) plus every epsilon-fair variant except
    # epsilon == 0, which is deliberately left out of this plot.
    curves = [(-1, plt_discrimination_risk_means, plt_discrimination_risk_stds)]
    curves += [(eps, plt_fair_discrimination_risk_means[eps], plt_fair_discrimination_risk_stds[eps])
               for eps in args.epsilons if eps != 0]

    for tag, curve_means, curve_stds in curves:
        label = ("Regular " if tag == -1 else str(tag) + "-Fair ") + filling_method_name
        # Summed over axis 1: one value per entry of experiment_list_0.
        risk_ax.errorbar(args.experiment_list_0,
                         curve_means.sum(dim=1),
                         yerr=curve_stds.sum(dim=1),
                         label=label)

    risk_ax.set_xlabel(args.axis_0)
    risk_ax.set_ylabel('Discrimination Risk')
    risk_ax.legend()

    alpha_ax.errorbar(args.experiment_list_0,
                      plt_alpha_means.sum(dim=1),
                      yerr=plt_alpha_stds.sum(dim=1))
    alpha_ax.set_xlabel(args.axis_0)
    alpha_ax.set_ylabel('Max Alpha')

    plt.tight_layout()

    variable_slug = '_'.join(args.experiment_variable.lower().split())
    plt.savefig('results/{}-{}-{}-{}.pdf'.format(args.dataset_name,
                                                 args.filling_method,
                                                 variable_slug,
                                                 'discrimination_risk'))

if args.experiment_variable == 'Group Size':

    # One figure per prediction model.
    for model_type in args.models:
        plt.figure()
        ax = plt.axes()

        # Regular method (tagged -1) first, then each epsilon-fair variant.
        curves = [(-1, plt_test_dp_means[model_type], plt_test_dp_stds[model_type])]
        curves += [(eps, plt_fair_test_dp_means[eps][model_type], plt_fair_test_dp_stds[eps][model_type])
                   for eps in args.epsilons]

        for tag, curve_means, curve_stds in curves:
            label = ("Regular " if tag == -1 else str(tag) + "-Fair ") + filling_method_name
            # Summed over axis 1: one value per entry of experiment_list_0.
            ax.errorbar(args.experiment_list_0,
                        curve_means.sum(dim=1),
                        yerr=curve_stds.sum(dim=1),
                        label=label)

        ax.set_xlabel(args.axis_0)
        ax.set_ylabel('Test Demographic Parity of {} model'.format(model_type.upper()))
        # Layout is tightened before the legend is added in this block.
        plt.tight_layout()
        plt.legend()

        variable_slug = '_'.join(args.experiment_variable.lower().split())
        plt.savefig('results/{}-{}-{}-{}-{}.pdf'.format(args.dataset_name,
                                                        args.filling_method,
                                                        variable_slug,
                                                        model_type,
                                                        'test_dp'))

# ## Link Rate

if args.experiment_variable == 'Link Rate':
    plt.figure()
    ax = plt.axes()

    # Regular method (tagged -1) first, then each epsilon-fair variant.
    curves = [(-1, plt_reconstruction_error_means, plt_reconstruction_error_stds)]
    curves += [(eps, plt_fair_reconstruction_error_means[eps], plt_fair_reconstruction_error_stds[eps])
               for eps in args.epsilons]

    for tag, curve_means, curve_stds in curves:
        label = ("Regular " if tag == -1 else str(tag) + "-Fair ") + filling_method_name

        # Flatten the (experiment_list_0 x experiment_list_1) grid into
        # (row, column, ratio) triples, where the ratio is the
        # experiment_list_1 value divided by the experiment_list_0 value.
        cells = [(row, col, inter_link / intra_link)
                 for row, intra_link in enumerate(args.experiment_list_0)
                 for col, inter_link in enumerate(args.experiment_list_1)]
        ratios = [ratio for _, _, ratio in cells]

        # Plot the points in order of increasing ratio.
        order = np.argsort(ratios)
        xs = np.array(ratios)[order]
        ys = np.array([curve_means[row, col] for row, col, _ in cells])[order]
        errs = np.array([curve_stds[row, col] for row, col, _ in cells])[order]
        ax.errorbar(xs, ys, yerr=errs, label=label)

    ax.set_xlabel(args.axis_1 + " / " + args.axis_0)
    ax.set_ylabel("Reconstruction Error")
    plt.legend()
    plt.tight_layout()

    variable_slug = '_'.join(args.experiment_variable.lower().split())
    plt.savefig('results/{}-{}-{}-{}.pdf'.format(args.dataset_name,
                                                 args.filling_method,
                                                 variable_slug,
                                                 'reconstruction_error'))

if args.experiment_variable == 'Link Rate':

    # One figure per prediction model.
    for model_type in args.models:
        plt.figure()
        ax = plt.axes()

        # Regular method (tagged -1) first, then each epsilon-fair variant.
        curves = [(-1, plt_test_acc_means[model_type], plt_test_acc_stds[model_type])]
        curves += [(eps, plt_fair_test_acc_means[eps][model_type], plt_fair_test_acc_stds[eps][model_type])
                   for eps in args.epsilons]

        for tag, curve_means, curve_stds in curves:
            label = ("Regular " if tag == -1 else str(tag) + "-Fair ") + filling_method_name

            # Flatten the grid into (row, column, ratio) triples, where the
            # ratio is the experiment_list_1 value over the
            # experiment_list_0 value.
            cells = [(row, col, inter_link / intra_link)
                     for row, intra_link in enumerate(args.experiment_list_0)
                     for col, inter_link in enumerate(args.experiment_list_1)]
            ratios = [ratio for _, _, ratio in cells]

            # Plot the points in order of increasing ratio.
            order = np.argsort(ratios)
            xs = np.array(ratios)[order]
            ys = np.array([curve_means[row, col] for row, col, _ in cells])[order]
            errs = np.array([curve_stds[row, col] for row, col, _ in cells])[order]
            ax.errorbar(xs, ys, yerr=errs, label=label)

        ax.set_xlabel(args.axis_1 + " / " + args.axis_0)
        ax.set_ylabel('Test Accuracy of {} model'.format(model_type.upper()))
        plt.legend()
        plt.tight_layout()

        variable_slug = '_'.join(args.experiment_variable.lower().split())
        plt.savefig('results/{}-{}-{}-{}-{}.pdf'.format(args.dataset_name,
                                                        args.filling_method,
                                                        variable_slug,
                                                        model_type,
                                                        'test_accuracy'))

if args.experiment_variable == 'Link Rate':
    # Top panel: discrimination risk curves; bottom panel: max alpha.
    fig, axs = plt.subplots(2)
    risk_ax, alpha_ax = axs[0], axs[1]

    # Regular method (tagged -1) plus every epsilon-fair variant except
    # epsilon == 0, which is deliberately left out of this plot.
    curves = [(-1, plt_discrimination_risk_means, plt_discrimination_risk_stds)]
    curves += [(eps, plt_fair_discrimination_risk_means[eps], plt_fair_discrimination_risk_stds[eps])
               for eps in args.epsilons if eps != 0]

    for tag, curve_means, curve_stds in curves:
        label = ("Regular " if tag == -1 else str(tag) + "-Fair ") + filling_method_name

        # Flatten the grid into (row, column, ratio) triples, where the ratio
        # is the experiment_list_1 value over the experiment_list_0 value.
        cells = [(row, col, inter_link / intra_link)
                 for row, intra_link in enumerate(args.experiment_list_0)
                 for col, inter_link in enumerate(args.experiment_list_1)]
        ratios = [ratio for _, _, ratio in cells]

        # Plot the points in order of increasing ratio.
        order = np.argsort(ratios)
        xs = np.array(ratios)[order]
        ys = np.array([curve_means[row, col] for row, col, _ in cells])[order]
        errs = np.array([curve_stds[row, col] for row, col, _ in cells])[order]
        risk_ax.errorbar(xs, ys, yerr=errs, label=label)

    risk_ax.set_xlabel(args.axis_1 + " / " + args.axis_0)
    risk_ax.set_ylabel('Discrimination Risk')
    risk_ax.legend()

    # Same flattening and ordering for the alpha statistics.
    cells = [(row, col, inter_link / intra_link)
             for row, intra_link in enumerate(args.experiment_list_0)
             for col, inter_link in enumerate(args.experiment_list_1)]
    ratios = [ratio for _, _, ratio in cells]

    order = np.argsort(ratios)
    xs = np.array(ratios)[order]
    ys = np.array([plt_alpha_means[row, col] for row, col, _ in cells])[order]
    errs = np.array([plt_alpha_stds[row, col] for row, col, _ in cells])[order]
    alpha_ax.errorbar(xs, ys, yerr=errs)
    alpha_ax.set_xlabel(args.axis_1 + " / " + args.axis_0)
    alpha_ax.set_ylabel('Max Alpha')

    plt.tight_layout()

    variable_slug = '_'.join(args.experiment_variable.lower().split())
    plt.savefig('results/{}-{}-{}-{}.pdf'.format(args.dataset_name,
                                                 args.filling_method,
                                                 variable_slug,
                                                 'discrimination_risk'))

if args.experiment_variable == 'Link Rate':

    # One figure per prediction model.
    for model_type in args.models:
        plt.figure()
        ax = plt.axes()

        # Regular method (tagged -1) first, then each epsilon-fair variant.
        curves = [(-1, plt_test_dp_means[model_type], plt_test_dp_stds[model_type])]
        curves += [(eps, plt_fair_test_dp_means[eps][model_type], plt_fair_test_dp_stds[eps][model_type])
                   for eps in args.epsilons]

        for tag, curve_means, curve_stds in curves:
            label = ("Regular " if tag == -1 else str(tag) + "-Fair ") + filling_method_name

            # Flatten the grid into (row, column, ratio) triples, where the
            # ratio is the experiment_list_1 value over the
            # experiment_list_0 value.
            cells = [(row, col, inter_link / intra_link)
                     for row, intra_link in enumerate(args.experiment_list_0)
                     for col, inter_link in enumerate(args.experiment_list_1)]
            ratios = [ratio for _, _, ratio in cells]

            # Plot the points in order of increasing ratio.
            order = np.argsort(ratios)
            xs = np.array(ratios)[order]
            ys = np.array([curve_means[row, col] for row, col, _ in cells])[order]
            errs = np.array([curve_stds[row, col] for row, col, _ in cells])[order]
            ax.errorbar(xs, ys, yerr=errs, label=label)

        ax.set_xlabel(args.axis_1 + " / " + args.axis_0)
        ax.set_ylabel('Test Demographic Parity of {} model'.format(model_type.upper()))
        # Layout is tightened before the legend is added in this block.
        plt.tight_layout()
        plt.legend()

        variable_slug = '_'.join(args.experiment_variable.lower().split())
        plt.savefig('results/{}-{}-{}-{}-{}.pdf'.format(args.dataset_name,
                                                        args.filling_method,
                                                        variable_slug,
                                                        model_type,
                                                        'test_dp'))
